{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wJcYs_ERTnnI"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "HMUDt0CiUJk9"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77z2OchJTk0l"
      },
      "source": [
        "# Migrate early stopping\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/migrate/early_stopping\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/migrate/early_stopping.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/early_stopping.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/migrate/early_stopping.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "meUTrR4I6m1C"
      },
      "source": [
        "This notebook demonstrates how to set up model training with early stopping, first in TensorFlow 1 with `tf.estimator.Estimator` and an early stopping hook, and then in TensorFlow 2 with Keras APIs or a custom training loop. Early stopping is a regularization technique that ends training once a stopping condition is met, for example, when the validation loss stops improving or reaches a certain threshold.\n",
        "\n",
        "In TensorFlow 2, there are three ways to implement early stopping:\n",
        "- Use a built-in Keras callback—`tf.keras.callbacks.EarlyStopping`—and pass it to `Model.fit`.\n",
        "- Define a custom callback and pass it to Keras `Model.fit`.\n",
        "- Write a custom early stopping rule in a [custom training loop](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch) (with `tf.GradientTape`)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YdZSoIXEbhg-"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iE0vSfMXumKI"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow.compat.v1 as tf1\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4uXff1BEssdE"
      },
      "source": [
        "## TensorFlow 1: Early stopping with an early stopping hook and tf.estimator"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JaHhhhW5o8lL"
      },
      "source": [
        "Start by defining functions for loading and preprocessing the MNIST dataset, and a model function to use with `tf.estimator.Estimator`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lqe9obf7suIj"
      },
      "outputs": [],
      "source": [
        "def normalize_img(image, label):\n",
        "  return tf.cast(image, tf.float32) / 255., label\n",
        "\n",
        "def _input_fn():\n",
        "  ds_train = tfds.load(\n",
        "    name='mnist',\n",
        "    split='train',\n",
        "    shuffle_files=True,\n",
        "    as_supervised=True)\n",
        "\n",
        "  ds_train = ds_train.map(\n",
        "      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n",
        "  ds_train = ds_train.batch(128)\n",
        "  ds_train = ds_train.repeat(100)\n",
        "  return ds_train\n",
        "\n",
        "def _eval_input_fn():\n",
        "  ds_test = tfds.load(\n",
        "    name='mnist',\n",
        "    split='test',\n",
        "    shuffle_files=True,\n",
        "    as_supervised=True)\n",
        "  ds_test = ds_test.map(\n",
        "    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n",
        "  ds_test = ds_test.batch(128)\n",
        "  return ds_test\n",
        "\n",
        "def _model_fn(features, labels, mode):\n",
        "  flatten = tf1.layers.Flatten()(features)\n",
        "  features = tf1.layers.Dense(128, 'relu')(flatten)\n",
        "  logits = tf1.layers.Dense(10)(features)\n",
        "\n",
        "  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)\n",
        "  optimizer = tf1.train.AdagradOptimizer(0.005)\n",
        "  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n",
        "\n",
        "  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hC_AY7KwqD0p"
      },
      "source": [
        "In TensorFlow 1, early stopping works by setting up an early stopping hook with `tf.estimator.experimental.make_early_stopping_hook`. You pass `make_early_stopping_hook` a `should_stop_fn` argument: a function that takes no arguments and returns a boolean. The hook calls `should_stop_fn` periodically, at the interval set by `run_every_secs` or `run_every_steps`, and training stops once it returns `True`.\n",
        "\n",
        "The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HsOpjW5plH9Q"
      },
      "outputs": [],
      "source": [
        "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n",
        "\n",
        "start_time = time.time()\n",
        "max_train_seconds = 20\n",
        "\n",
        "def should_stop_fn():\n",
        "  return time.time() - start_time > max_train_seconds\n",
        "\n",
        "early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(\n",
        "    estimator=estimator,\n",
        "    should_stop_fn=should_stop_fn,\n",
        "    run_every_secs=1,\n",
        "    run_every_steps=None)\n",
        "\n",
        "train_spec = tf1.estimator.TrainSpec(\n",
        "    input_fn=_input_fn,\n",
        "    hooks=[early_stopping_hook])\n",
        "\n",
        "eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)\n",
        "\n",
        "tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)"
      ]
    },
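    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f3b7c1d9e2a0"
      },
      "source": [
        "As a variation on the example above, here is a minimal sketch of the same `should_stop_fn` contract that does not rely on a global `start_time`. The factory `make_time_limit_fn` is a hypothetical helper, not a TensorFlow API: it starts the clock when it is called and returns a function with no arguments, which is what `make_early_stopping_hook` expects. The injectable `clock` parameter is an assumption of this sketch that lets you check the logic with a fake clock instead of waiting. To use it, you could pass `should_stop_fn=make_time_limit_fn(20)` to `make_early_stopping_hook` right before calling `train_and_evaluate`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c8d2e6a4b1f7"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "# Hypothetical helper, not part of TensorFlow: builds a no-argument\n",
        "# `should_stop_fn` suitable for `make_early_stopping_hook`.\n",
        "def make_time_limit_fn(max_seconds, clock=time.monotonic):\n",
        "  start = clock()  # The time budget starts when the factory is called.\n",
        "\n",
        "  def should_stop_fn():\n",
        "    # Returns True once more than `max_seconds` have elapsed since `start`.\n",
        "    return clock() - start > max_seconds\n",
        "\n",
        "  return should_stop_fn\n",
        "\n",
        "# Check the logic with a fake clock instead of waiting in real time.\n",
        "fake_now = [100.0]\n",
        "check_fn = make_time_limit_fn(20, clock=lambda: fake_now[0])\n",
        "print(check_fn())  # False: no time has passed yet.\n",
        "fake_now[0] = 121.0\n",
        "print(check_fn())  # True: 21 seconds have passed on the fake clock."
      ]
    },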
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KEmzBjfnsxwT"
      },
      "source": [
        "## TensorFlow 2: Early stopping with a built-in callback and Model.fit"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GKwxnkIksPFW"
      },
      "source": [
        "Prepare the MNIST dataset and a simple Keras model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "atVciNgPs0fw"
      },
      "outputs": [],
      "source": [
        "(ds_train, ds_test), ds_info = tfds.load(\n",
        "    'mnist',\n",
        "    split=['train', 'test'],\n",
        "    shuffle_files=True,\n",
        "    as_supervised=True,\n",
        "    with_info=True,\n",
        ")\n",
        "\n",
        "ds_train = ds_train.map(\n",
        "    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n",
        "ds_train = ds_train.batch(128)\n",
        "\n",
        "ds_test = ds_test.map(\n",
        "    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n",
        "ds_test = ds_test.batch(128)\n",
        "\n",
        "model = tf.keras.models.Sequential([\n",
        "  tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
        "  tf.keras.layers.Dense(128, activation='relu'),\n",
        "  tf.keras.layers.Dense(10)\n",
        "])\n",
        "\n",
        "model.compile(\n",
        "    optimizer=tf.keras.optimizers.Adam(0.005),\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "559Goxp3tOMl"
      },
      "source": [
        "In TensorFlow 2, when you train with the built-in Keras `Model.fit`, you can configure early stopping by passing the built-in `tf.keras.callbacks.EarlyStopping` callback to the `callbacks` parameter of `Model.fit`.\n",
        "\n",
        "The `EarlyStopping` callback monitors a user-specified metric and ends training when it stops improving. (Check the [Training and evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate#using_callbacks) or the [API docs](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping) for more information.)\n",
        "\n",
        "Below is an example of an early stopping callback that monitors the training loss (`monitor='loss'`) and stops training once it has not improved for `3` consecutive epochs (`patience=3`):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kip65sYBlKiu"
      },
      "outputs": [],
      "source": [
        "callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)\n",
        "\n",
        "# Training can end before the 100-epoch limit once the loss stops improving;\n",
        "# the number of epochs that actually run varies between runs.\n",
        "history = model.fit(\n",
        "    ds_train,\n",
        "    epochs=100,\n",
        "    validation_data=ds_test,\n",
        "    callbacks=[callback]\n",
        ")\n",
        "\n",
        "len(history.history['loss'])"
      ]
    },
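    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d5a9f3c7e1b2"
      },
      "source": [
        "`EarlyStopping` accepts more arguments than `monitor` and `patience`. As a self-contained sketch, the cell below trains a small model on randomly generated data; the data, the `demo_model` architecture, and the argument values are assumptions chosen only for illustration. The callback monitors the validation loss, ignores decreases smaller than `min_delta`, and sets `restore_best_weights=True` so that, when training stops early, the model is rolled back to the weights from its best epoch instead of keeping the weights from the last epoch. The separate names (`demo_model`, `demo_history`) leave the `model` defined above untouched."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e7c1b5d9a3f4"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical synthetic binary classification data, so this cell runs on its own.\n",
        "rng = np.random.default_rng(0)\n",
        "x_demo = rng.normal(size=(1000, 20)).astype('float32')\n",
        "y_demo = (x_demo[:, 0] > 0).astype('int32')\n",
        "x_demo_train, y_demo_train = x_demo[:800], y_demo[:800]\n",
        "x_demo_val, y_demo_val = x_demo[800:], y_demo[800:]\n",
        "\n",
        "demo_model = tf.keras.Sequential([\n",
        "    tf.keras.Input(shape=(20,)),\n",
        "    tf.keras.layers.Dense(16, activation='relu'),\n",
        "    tf.keras.layers.Dense(2),\n",
        "])\n",
        "demo_model.compile(\n",
        "    optimizer=tf.keras.optimizers.Adam(0.01),\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        ")\n",
        "\n",
        "early_stopping = tf.keras.callbacks.EarlyStopping(\n",
        "    monitor='val_loss',         # Watch the validation loss, not the training loss.\n",
        "    min_delta=1e-3,             # Smaller decreases do not count as an improvement.\n",
        "    patience=3,                 # Stop after 3 epochs without improvement.\n",
        "    restore_best_weights=True,  # Roll back to the weights of the best epoch.\n",
        ")\n",
        "\n",
        "demo_history = demo_model.fit(\n",
        "    x_demo_train, y_demo_train,\n",
        "    epochs=100,\n",
        "    validation_data=(x_demo_val, y_demo_val),\n",
        "    callbacks=[early_stopping],\n",
        "    verbose=0,\n",
        ")\n",
        "print('Epochs run:', len(demo_history.history['loss']))"
      ]
    },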
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a92c6ebb1a1c"
      },
      "source": [
        "## TensorFlow 2: Early stopping with a custom callback and Model.fit"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wCwZ4BA8jaHY"
      },
      "source": [
        "You can also implement a [custom early stopping callback](https://www.tensorflow.org/guide/keras/custom_callback/#early_stopping_at_minimum_loss) and pass it to the `callbacks` parameter of `Model.fit`.\n",
        "\n",
        "In this example, the callback ends training once the elapsed time exceeds a limit, by setting `self.model.stop_training` to `True`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hns1fmwtjCg2"
      },
      "outputs": [],
      "source": [
        "class LimitTrainingTime(tf.keras.callbacks.Callback):\n",
        "  def __init__(self, max_time_s):\n",
        "    super().__init__()\n",
        "    self.max_time_s = max_time_s\n",
        "    self.start_time = None\n",
        "\n",
        "  def on_train_begin(self, logs=None):\n",
        "    self.start_time = time.time()\n",
        "\n",
        "  def on_train_batch_end(self, batch, logs=None):\n",
        "    # Ask Keras to stop training once the time budget is used up.\n",
        "    if time.time() - self.start_time > self.max_time_s:\n",
        "      self.model.stop_training = True"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s5mIzDOAkUKA"
      },
      "outputs": [],
      "source": [
        "# Limit the training time to 30 seconds.\n",
        "callback = LimitTrainingTime(30)\n",
        "history = model.fit(\n",
        "    ds_train,\n",
        "    epochs=100,\n",
        "    validation_data=ds_test,\n",
        "    callbacks=[callback]\n",
        ")\n",
        "len(history.history['loss'])"
      ]
    },
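    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b2f6d0a8c4e9"
      },
      "source": [
        "The `LimitTrainingTime` callback stops on a time budget, but a custom callback can also apply a patience rule to a metric. The hypothetical `StopWhenValLossStalls` class below is a minimal sketch, not the implementation of the built-in `EarlyStopping`: it reads `val_loss` from the `logs` passed to `on_epoch_end`, which assumes you pass `validation_data` to `Model.fit`. To keep the cell fast, it checks the logic by calling the callback methods directly with made-up loss values and a tiny placeholder model instead of running `Model.fit`. In a real run, you could pass `StopWhenValLossStalls(patience=3)` in the `callbacks` list together with `validation_data=ds_test`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a4e8c2f6d0b3"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "class StopWhenValLossStalls(tf.keras.callbacks.Callback):\n",
        "  # Hypothetical sketch: stop when `val_loss` has not improved for `patience` epochs.\n",
        "  def __init__(self, patience=3):\n",
        "    super().__init__()\n",
        "    self.patience = patience\n",
        "\n",
        "  def on_train_begin(self, logs=None):\n",
        "    # Reset the state at the start of every `Model.fit` call.\n",
        "    self.best = np.inf\n",
        "    self.wait = 0\n",
        "    self.stopped_epoch = None\n",
        "\n",
        "  def on_epoch_end(self, epoch, logs=None):\n",
        "    current = (logs or {}).get('val_loss')\n",
        "    if current is None:\n",
        "      return  # No validation loss was logged, for example without `validation_data`.\n",
        "    if current < self.best:\n",
        "      self.best = current\n",
        "      self.wait = 0\n",
        "    else:\n",
        "      self.wait += 1\n",
        "      if self.wait >= self.patience:\n",
        "        self.stopped_epoch = epoch\n",
        "        self.model.stop_training = True\n",
        "\n",
        "# Simulate four epochs with made-up validation losses and a tiny placeholder model.\n",
        "stall_callback = StopWhenValLossStalls(patience=2)\n",
        "stall_callback.set_model(\n",
        "    tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)]))\n",
        "stall_callback.on_train_begin()\n",
        "for epoch, fake_val_loss in enumerate([1.0, 0.8, 0.9, 0.85]):\n",
        "  stall_callback.on_epoch_end(epoch, {'val_loss': fake_val_loss})\n",
        "  if stall_callback.stopped_epoch is not None:\n",
        "    break\n",
        "print('Stopped at epoch:', stall_callback.stopped_epoch)  # Stopped at epoch: 3"
      ]
    },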
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kro_lKyEu60-"
      },
      "source": [
        "## TensorFlow 2: Early stopping with a custom training loop"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g5LU0lebvuIk"
      },
      "source": [
        "In TensorFlow 2, you can implement early stopping in a [custom training loop](https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough#training_loop) if you're not training and evaluating with the [built-in Keras methods](https://www.tensorflow.org/guide/keras/train_and_evaluate).\n",
        "\n",
        "Start by using Keras APIs to define another simple model, an optimizer, a loss function, and metrics:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oTGxr0PwAiQ4"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.models.Sequential([\n",
        "  tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
        "  tf.keras.layers.Dense(128, activation='relu'),\n",
        "  tf.keras.layers.Dense(10)\n",
        "])\n",
        "\n",
        "optimizer = tf.keras.optimizers.Adam(0.005)\n",
        "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
        "train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()\n",
        "val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n",
        "val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zecsnqRxvy0Q"
      },
      "source": [
        "Define the training and evaluation step functions [with `tf.GradientTape`](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch), and decorate them with `@tf.function` [for a speedup](https://www.tensorflow.org/guide/function):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s3w_55n0Ah7L"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(x, y):\n",
        "  with tf.GradientTape() as tape:\n",
        "    logits = model(x, training=True)\n",
        "    loss_value = loss_fn(y, logits)\n",
        "  grads = tape.gradient(loss_value, model.trainable_weights)\n",
        "  optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "  train_acc_metric.update_state(y, logits)\n",
        "  train_loss_metric.update_state(y, logits)\n",
        "  return loss_value\n",
        "\n",
        "@tf.function\n",
        "def test_step(x, y):\n",
        "  logits = model(x, training=False)\n",
        "  val_acc_metric.update_state(y, logits)\n",
        "  val_loss_metric.update_state(y, logits)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-ZKS9ePGwd9r"
      },
      "source": [
        "Next, write a custom training loop, where you can implement your early stopping rule manually.\n",
        "\n",
        "The example below stops training when the validation loss has not improved for `patience` consecutive epochs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iZOzHqqSAkpK"
      },
      "outputs": [],
      "source": [
        "epochs = 100\n",
        "patience = 5\n",
        "wait = 0\n",
        "best = float('inf')\n",
        "\n",
        "for epoch in range(epochs):\n",
        "  print(\"\\nStart of epoch %d\" % (epoch,))\n",
        "  start_time = time.time()\n",
        "\n",
        "  for step, (x_batch_train, y_batch_train) in enumerate(ds_train):\n",
        "    loss_value = train_step(x_batch_train, y_batch_train)\n",
        "    if step % 200 == 0:\n",
        "      print(\"Training loss at step %d: %.4f\" % (step, loss_value.numpy()))\n",
        "      print(\"Seen so far: %s samples\" % ((step + 1) * 128))\n",
        "  train_acc = train_acc_metric.result()\n",
        "  train_loss = train_loss_metric.result()\n",
        "  train_acc_metric.reset_state()\n",
        "  train_loss_metric.reset_state()\n",
        "  print(\"Training acc over epoch: %.4f\" % (float(train_acc),))\n",
        "  print(\"Training loss over epoch: %.4f\" % (float(train_loss),))\n",
        "\n",
        "  for x_batch_val, y_batch_val in ds_test:\n",
        "    test_step(x_batch_val, y_batch_val)\n",
        "  val_acc = val_acc_metric.result()\n",
        "  val_loss = float(val_loss_metric.result())\n",
        "  val_acc_metric.reset_state()\n",
        "  val_loss_metric.reset_state()\n",
        "  print(\"Validation acc: %.4f\" % (float(val_acc),))\n",
        "  print(\"Validation loss: %.4f\" % (val_loss,))\n",
        "  print(\"Time taken: %.2fs\" % (time.time() - start_time))\n",
        "\n",
        "  # The early stopping strategy: stop training if `val_loss` has not\n",
        "  # decreased for `patience` consecutive epochs.\n",
        "  wait += 1\n",
        "  if val_loss < best:\n",
        "    best = val_loss\n",
        "    wait = 0\n",
        "  if wait >= patience:\n",
        "    print(\"Early stopping after epoch %d; best validation loss: %.4f\" % (epoch, best))\n",
        "    break"
      ]
    },
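    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f9d3b7e1c5a2"
      },
      "source": [
        "If you reuse the same rule across several training loops, you might factor it out of the loop body. The `EarlyStopper` class below is a hypothetical, framework-independent helper (not a TensorFlow API) that implements the same patience logic as the loop above, plus an optional `min_delta` threshold that is an addition of this sketch. Here it runs on a made-up sequence of validation losses; inside the loop above, you could call `stopper.should_stop(val_loss)` once per epoch and `break` when it returns `True`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c1a5e9d3b7f4"
      },
      "outputs": [],
      "source": [
        "class EarlyStopper:\n",
        "  # Hypothetical helper: tracks the best value of a quantity where lower is better.\n",
        "  def __init__(self, patience=5, min_delta=0.0):\n",
        "    self.patience = patience\n",
        "    self.min_delta = min_delta\n",
        "    self.best = float('inf')\n",
        "    self.wait = 0\n",
        "\n",
        "  def should_stop(self, value):\n",
        "    # Records one epoch's value and returns True once `patience` consecutive\n",
        "    # epochs have passed without an improvement larger than `min_delta`.\n",
        "    value = float(value)  # Also accepts scalar tensors, such as a metric result.\n",
        "    if value < self.best - self.min_delta:\n",
        "      self.best = value\n",
        "      self.wait = 0\n",
        "    else:\n",
        "      self.wait += 1\n",
        "    return self.wait >= self.patience\n",
        "\n",
        "# Made-up validation losses, used only to illustrate the rule.\n",
        "stopper = EarlyStopper(patience=2)\n",
        "for epoch, fake_val_loss in enumerate([0.9, 0.7, 0.71, 0.69, 0.72, 0.73]):\n",
        "  if stopper.should_stop(fake_val_loss):\n",
        "    print('Stopping after epoch %d; best val_loss: %.2f' % (epoch, stopper.best))\n",
        "    break\n",
        "# Expected output: Stopping after epoch 5; best val_loss: 0.69"
      ]
    },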
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e85558980a4b"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "- Learn more about the Keras built-in early stopping callback API in the [API docs](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping).\n",
        "- Learn to [write custom Keras callbacks](https://www.tensorflow.org/guide/keras/custom_callback), including [early stopping at a minimum loss](https://www.tensorflow.org/guide/keras/custom_callback/#early_stopping_at_minimum_loss).\n",
        "- Learn about [Training and evaluation with the Keras built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate#using_callbacks).\n",
        "- Explore common regularization techniques in the [Overfit and underfit](https://www.tensorflow.org/tutorials/keras/overfit_and_underfit) tutorial that uses the `EarlyStopping` callback."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "early_stopping.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
